from collections import defaultdict
from dataclasses import asdict

import pytest

import tests.functional.services.policy_engine.utils.api as policy_engine_api
from tests.functional.services.policy_engine.conftest import is_legacy_provider
from tests.functional.services.policy_engine.utils.utils import (
    VulnerabilityQuery,
    VulnerabilityQueryMetadata,
)
from tests.functional.services.utils import http_utils


def _ordered(obj):
    if isinstance(obj, dict):
        return sorted((k, _ordered(v)) for k, v in obj.items())
    if isinstance(obj, list):
        return sorted(_ordered(x) for x in obj)
    else:
        return obj


def sort_nvd_data(elem):
    """Build a sort key for an nvd_data entry.

    The key is ``[id, cvss_v2 vector string]``. Entries whose ``cvss_v2`` is
    absent or falsy use the placeholder ``"No cvss_v2"`` as the second part.
    """
    record_id = elem["id"]
    cvss_v2 = elem.get("cvss_v2")
    vector = cvss_v2["vector_string"] if cvss_v2 else "No cvss_v2"
    return [record_id, vector]


# Namespace holding NVD records: the legacy provider uses "nvdv2:cves", otherwise "nvd".
if is_legacy_provider():
    nvd_namespace = "nvdv2:cves"
else:
    nvd_namespace = "nvd"

# Tests that verify vulnerability query results against the set of vulnerabilities seeded into the database.
class TestQueryVulnerabilities:
    """Vulnerability query API tests run against the vulnerabilities seeded into the database."""

    @pytest.mark.parametrize(
        "query",
        [
            VulnerabilityQuery(["CVE-2017-7245"], "single_cve"),
            VulnerabilityQuery(
                ["CVE-2017-7245", "CVE-2014-4617", "CVE-2018-5709"], "multiple_cves"
            ),
            VulnerabilityQuery(
                ["CVE-2017-7245"],
                "single_cve_filter_affected_package",
                VulnerabilityQueryMetadata(affected_package="pcre3"),
            ),
            VulnerabilityQuery(
                ["CVE-2017-7245", "CVE-2017-11164"],
                "multiple_cves_filter_affected_package",
                VulnerabilityQueryMetadata(affected_package="pcre3"),
            ),
            VulnerabilityQuery(
                ["CVE-2017-7245"],
                "single_cve_multiple_filters",
                VulnerabilityQueryMetadata(
                    affected_package="pcre3", namespace="debian:10"
                ),
            ),
            VulnerabilityQuery(
                ["CVE-2017-18018"],
                "single_cve_multiple_filters_2",
                VulnerabilityQueryMetadata(
                    affected_package="coreutils",
                    namespace=nvd_namespace,
                    affected_package_version="8.9",
                ),
            ),
            VulnerabilityQuery(
                ["CVE-2017-7245", "CVE-2017-11164"],
                "multiple_cves_multiple_filters",
                VulnerabilityQueryMetadata(
                    affected_package="pcre3", namespace="debian:10"
                ),
            ),
        ],
    )
    def test_query_vulnerabilities(self, query, expected_content, is_legacy_test):
        """Query vulnerabilities by id and filters, then validate every returned record.

        For each returned vulnerability this checks that:

        * its id was one of the requested ids;
        * the namespace filter, if any, was honored;
        * every affected package carries a boolean ``will_not_fix`` flag;
        * its ``nvd_data`` matches what a direct query of the nvd namespace
          returns for the same ids.

        When ``is_legacy_test`` is truthy, it additionally checks the affected
        package/version filters. It also compares the whole response body,
        ignoring ordering, against ``expected_content(query.expected_output_file)``.

        :param query: the VulnerabilityQuery under test (ids, expected output file, filters)
        :param expected_content: fixture called with the expected output file name;
            its result is the body the legacy comparison expects
        :param is_legacy_test: fixture that gates the legacy-only checks
        """
        metadata = query.query_metadata
        vulnerabilities_resp = (
            policy_engine_api.query_vulnerabilities.get_vulnerabilities(
                query.id, **asdict(metadata)
            )
        )

        assert vulnerabilities_resp == http_utils.APIResponse(200)
        assert len(vulnerabilities_resp.body) > 0

        for vuln in vulnerabilities_resp.body:
            assert vuln["id"] in query.id

            if metadata.namespace:
                assert vuln["namespace"] == metadata.namespace

            # Every affected package must carry a boolean will_not_fix flag
            for package in vuln["affected_packages"]:
                assert "will_not_fix" in package
                assert isinstance(package["will_not_fix"], bool)

            if is_legacy_test and metadata.affected_package:
                # Map package name -> list of affected versions reported for it
                package_versions = defaultdict(list)
                for package in vuln["affected_packages"]:
                    package_versions[package["name"]].append(package["version"])

                assert metadata.affected_package in package_versions

                if metadata.affected_package_version:
                    versions = package_versions[metadata.affected_package]
                    # A "*" version entry is accepted as matching any requested version
                    assert (
                        metadata.affected_package_version in versions
                        or "*" in versions
                    )

            # Cross-check the embedded nvd_data against a direct query of the
            # nvd namespace for the same related ids.
            assert len(vuln["nvd_data"]) > 0
            related_ids = {related_vuln["id"] for related_vuln in vuln["nvd_data"]}
            nvd_resp = policy_engine_api.query_vulnerabilities.get_vulnerabilities(
                related_ids, namespace=nvd_namespace
            )
            # Fail clearly here rather than while iterating an error body below
            assert nvd_resp == http_utils.APIResponse(200)

            expected_nvd_data = sorted(
                (
                    nvd_entry
                    for nvd_record in nvd_resp.body
                    for nvd_entry in nvd_record["nvd_data"]
                ),
                key=sort_nvd_data,
            )
            actual_nvd_data = sorted(vuln["nvd_data"], key=sort_nvd_data)

            assert expected_nvd_data == actual_nvd_data

        if is_legacy_test:
            # _ordered canonicalizes nested dicts/lists so the comparison
            # ignores key and element ordering on both sides.
            expected = expected_content(query.expected_output_file)
            if expected:
                expected = _ordered(expected)
            results = vulnerabilities_resp.body
            if results:
                results = _ordered(results)
            assert results == expected

    @pytest.mark.skipif(
        not is_legacy_provider(),
        reason="affected package version not currently supported by grype",
    )
    @pytest.mark.parametrize(
        "query",
        [
            VulnerabilityQuery(
                ["CVE-2017-18018"],
                "expected_empty_incorrect_version",
                VulnerabilityQueryMetadata(
                    affected_package="coreutils",
                    namespace="nvdv2:cves",
                    affected_package_version="10",
                ),
            ),
        ],
    )
    def test_expected_empty_result(self, query):
        """A query whose filters match nothing returns 200 with an empty body.

        The test is skipped unless the legacy provider is in use. Per the skip
        reason, grype does not currently support affected package version
        filtering.
        """
        vulnerabilities_resp = (
            policy_engine_api.query_vulnerabilities.get_vulnerabilities(
                query.id, **asdict(query.query_metadata)
            )
        )

        assert vulnerabilities_resp == http_utils.APIResponse(200)
        assert len(vulnerabilities_resp.body) == 0
